
include("../src/qmcmary.jl")
using ..qmcmary

using Test
using ReverseDiff: jacobian, jacobian!
using LinearAlgebra

# Restrict BLAS to a single thread for everything run from this script.
# NOTE(review): presumably to avoid threading overhead / run-to-run variation on the
# small dense matrices used below — confirm intent.
BLAS.set_num_threads(1)

"""
    sgn_scratch(ss::ScrollSVD{T}) where T

Compute a log-magnitude from the SVD factors stored in `ss` and return it as a real
scalar.

A left triple `(VL, DL, UL)` and a right triple `(UR, DR, VR)` are picked from `ss.F`
according to `findfirst(ss.L)`:

- `nothing`: left is the identity, right is `ss.F[end]`;
- `1`: left is `ss.F[1]`, right is the identity;
- `k > 1`: left is `ss.F[k]`, right is `ss.F[k-1]`.

Each diagonal factor is split into a part holding its entries larger than one and a part
holding the remaining entries, the matrix

    M = inv(DRB)*(UL*UR)'*inv(DLB) + DRS*(VR*VL)*DLS

is decomposed with a QR-iteration SVD, and the returned value is

    log|det(Vt*UL)*det(UR*U)| + sum(log, S) + sum(log, DLB) + sum(log, DRB)

NOTE(review): going by the original derivation notes (`G = inv(I + UR*DR*VR*VL*DL*UL)`),
this should equal `log|det(I + UR*DR*VR*VL*DL*UL)|` provided `UL` and `UR` are unitary
(the adjoint is used in place of the inverse). Despite the `sgn` in the name, the result
is a log-magnitude, not a sign or phase — confirm against the callers' expectations.
"""
function sgn_scratch(ss::ScrollSVD{T}) where T
    n = size(ss.B[end])[1]
    k = findfirst(ss.L)
    if isnothing(k)
        # No entry of ss.L is set: identity on the left, last stored SVD on the right.
        VL = DL = UL = Diagonal(ones(n))
        fr = ss.F[end]
        UR, DR, VR = fr.U, Diagonal(fr.S), fr.Vt
    elseif k == 1
        # First entry of ss.L is set: first stored SVD on the left, identity on the right.
        fl = ss.F[1]
        VL, DL, UL = fl.U, Diagonal(fl.S), fl.Vt
        UR = DR = VR = Diagonal(ones(n))
        # Original author's note: this arrangement can suddenly lose numerical
        # stability. The suggested (currently NOT enabled) alternative swaps the roles:
        # identity on the left and ss.F[1] on the right, i.e.
        #   VL = DL = UL = I,  UR = ss.F[1].U,  DR = Diagonal(ss.F[1].S),  VR = ss.F[1].Vt
    else
        fl, fr = ss.F[k], ss.F[k-1]
        VL, DL, UL = fl.U, Diagonal(fl.S), fl.Vt
        UR, DR, VR = fr.U, Diagonal(fr.S), fr.Vt
    end

    # Split every diagonal entry d into (big, small): (d, 1) when d > 1, else (1, d),
    # so that big*small == d and neither part mixes very large with very small scales.
    bigpart(D) = Float64[D[i, i] > 1.0 ? D[i, i] : 1.0 for i in Base.OneTo(n)]
    smallpart(D) = Float64[D[i, i] > 1.0 ? 1.0 : D[i, i] for i in Base.OneTo(n)]
    Lbig, Lsmall = Diagonal(bigpart(DL)), Diagonal(smallpart(DL))
    Rbig, Rsmall = Diagonal(bigpart(DR)), Diagonal(smallpart(DR))

    # The product expressions are intentionally kept in this exact form: `a*b*c` is a
    # single multi-argument call, and regrouping it could change the rounding.
    M = inv(Rbig)*adjoint(UL*UR)*inv(Lbig) + Rsmall*(VR*VL)*Lsmall
    Fm = svd(M, alg=LinearAlgebra.QRIteration())

    # Magnitude of the determinant of the outer factors (the phase is discarded by abs).
    acc = log(abs(det(Fm.Vt*UL)*det(UR*Fm.U)))
    # Add the log of every singular value of M and of every "big" diagonal entry,
    # accumulating strictly left to right.
    for vals in (Fm.S, diag(Lbig), diag(Rbig)), v in vals
        acc += log(v)
    end
    return acc
end


"""
    apply_tp()

Print-only demonstration of left-multiplying by the exponential of a single symmetric
off-diagonal bond.

A random symmetric 9x9 matrix is exponentiated (`exp(0.1*hk)`), a matrix that is zero
except for the value 0.25 at `(2, 4)` and `(4, 2)` is exponentiated the same way, and
for each of the two rows the function prints: the row of the product, the original row,
and the `cosh`/`sinh` combination of the two original rows, so the first and third
printed vectors can be compared by eye. Nothing is asserted; returns `nothing`.
"""
function apply_tp()
    dtau = 0.1          # prefactor inside both exponentials
    amp = 0.25          # value placed on the single bond
    a, b = 2, 4         # the two sites joined by that bond

    seed = rand(9, 9)
    kin = seed + adjoint(seed)
    expk = exp(dtau*kin)

    bond = zeros(9, 9)
    bond[a, b] = amp
    bond[b, a] = amp
    expb = exp(dtau*bond)
    println(expb)

    mixed = expb*expk
    c, s = cosh(dtau*amp), sinh(dtau*amp)
    # Same three printouts for row `a` (partner `b`) and then row `b` (partner `a`).
    for (p, q) in ((a, b), (b, a))
        println(mixed[p, :])
        println(expk[p, :])
        println(expk[p, :]*c + expk[q, :]*s)
    end
    return nothing
end


"""
    pbptp_example()

Print-only diagnostic comparing the output of `pbptp_calc_xi` with a finite difference
of the first element of the `allbmats` list returned by `initialize_SS`.

Setup: a 4x4 lattice (`Nx = 16` sites) with `hk` from `lattice_tprim_square`, one bond
per site pointing to the site at `(ux+1, uy+1)` (indices passed through `@↻`), and a
random +/-1 configuration `hscfg` of size `Nt x Nx`. The bond attached to site `ni = 2`
is then shifted by 0.001 in both `hk[ni, j]` and `hk[j, ni]`, everything is rebuilt, and
rows of the real-part difference are printed next to slices of `adf_r`. Nothing is
asserted; returns `nothing`.

NOTE(review): `@↻(L, i)` is presumably a periodic wrap of `i` into `1:L` — confirm in
the package source.
"""
function pbptp_example()
    L = 4
    Nx = L^2
    Nt, Ng = 50, 5
    tp = 0.25
    δ = 0.001       # size of the hopping perturbation
    ni = 2          # site whose bond is perturbed
    nx, ny, nz = zeros(Nx), zeros(Nx), ones(Nx)
    Ui = ones(Nx)
    hk = lattice_tprim_square(ComplexF64, L, -1.0+0.0im, tp+0.0im)

    # Linear site index for lattice coordinates (ux, uy).
    ucmap = [L*(uy-1) + ux for ux in 1:L, uy in 1:L]
    # For every site: index of its bond partner and the bond strength.
    tpidx = zeros(Int, Nx)
    tpval = fill(tp, Nx)
    for ux in 1:L, uy in 1:L
        here = ucmap[ux, uy]
        tpidx[here] = ucmap[@↻(L, ux+1), @↻(L, uy+1)]
        tpval[here] = tp
    end

    # Random +/-1 field, drawn one row (Nx values) at a time.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for row in eachrow(hscfg)
        row .= 2 .* (round.(rand(Nx)) .- 0.5)
    end

    sp = default_splitting(Nt, hk, Ui; Z2=true)
    _, _, allbmats1, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    tapemap = pbptp_calc_tapemap(sp, tpidx, tpval)
    adf_r, _ = pbptp_calc_xi(sp, hscfg, ni, tpidx, tpval; tapemap=tapemap)

    # Perturb the bond attached to site `ni` (both symmetric entries) and rebuild.
    nj = tpidx[ni]
    hk[ni, nj] += δ
    hk[nj, ni] += δ
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    _, _, allbmats2, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    rdiff = real(allbmats2[1]) - real(allbmats1[1])
    # (label for rdiff row, label for adf_r slice, rdiff row, second index into adf_r)
    report = (
        ("r1", "adf", ni, 1),
        ("r1 2", "adf 2", nj, 3),
        ("r dn", "adf dn", ni+Nx, 2),
        ("r dn2", "adf dn2", nj+Nx, 4),
    )
    for (rlabel, alabel, r, a) in report
        println(rlabel, rdiff[r, :])
        println(alabel, adf_r[1, a, :])
    end
    println(hscfg[1, ni], " ", hscfg[1, nj])
    return nothing
end


"""
    tpbar_example()

Finite-difference check of the gradient returned by `meas_gradtp`.

For a fixed random configuration `hscfg` and the bond strengths `[t1, ..., tn]`, the
function evaluates `sgn_scratch` once and obtains `tpbar` from `meas_gradtp`. Each bond
is then shifted by 0.001 in turn (both symmetric entries of `hk`), `sgn_scratch` is
re-evaluated, and the numerical change is compared with `0.001*tpbar[i]` inside the
test set `"-∂log(s)/∂x"` (`rtol = 1e-2`, `atol = 2e-4`).

Setup: a 3x3 lattice (`Nx = 9` sites) with `hk` from `lattice_tprim_square` and one bond
per site pointing to the site at `(ux+1, uy+1)` (indices passed through `@↻`).

NOTE(review): `@↻(L, i)` is presumably a periodic wrap of `i` into `1:L` — confirm in
the package source.
"""
function tpbar_example()
    L = 3
    Nx = L^2
    Nt, Ng = 50, 5
    tp = 0.25
    δ = 0.001       # finite-difference step applied to each bond
    nx, ny, nz = zeros(Nx), zeros(Nx), ones(Nx)
    Ui = ones(Nx)
    hk = lattice_tprim_square(ComplexF64, L, -1.0+0.0im, tp+0.0im)

    # Linear site index for lattice coordinates (ux, uy).
    ucmap = [L*(uy-1) + ux for ux in 1:L, uy in 1:L]
    # For every site: index of its bond partner and the bond strength.
    tpidx = zeros(Int, Nx)
    tpval = fill(tp, Nx)
    for ux in 1:L, uy in 1:L
        here = ucmap[ux, uy]
        tpidx[here] = ucmap[@↻(L, ux+1), @↻(L, uy+1)]
        tpval[here] = tp
    end

    # Random +/-1 field, drawn one row (Nx values) at a time.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for row in eachrow(hscfg)
        row .= 2 .* (round.(rand(Nx)) .- 0.5)
    end

    # Reference value and analytic gradient at the unperturbed hk.
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
    logS1 = sgn_scratch(ss)
    tapemap = pbptp_calc_tapemap(sp, tpidx, tpval)
    tpbar = meas_gradtp(ss, allbmats, sp, hscfg, tpidx, tpval; tapemap=tapemap)

    admat = zeros(Float64, Nx, 1)   # analytic change, δ * gradient
    ndmat = zeros(Float64, Nx, 1)   # numerical change
    for xi in 1:Nx
        xj = tpidx[xi]
        hk[xi, xj] += δ
        hk[xj, xi] += δ
        sp2 = default_splitting(Nt, hk, Ui; Z2=true)
        _, _, _, ss2 = initialize_SS(Nt, Ng, sp2, hscfg, nx, ny, nz)
        # d log S = -d log|w|, hence the minus sign on the perturbed value.
        ndmat[xi, 1] = -sgn_scratch(ss2) + logS1
        admat[xi, 1] = δ*tpbar[xi]
        # Undo the shift in place before moving on to the next bond.
        hk[xi, xj] -= δ
        hk[xj, xi] -= δ
    end
    println(admat, ndmat, admat-ndmat)
    @testset "-∂log(s)/∂x" begin
        @test all(isapprox.(ndmat, admat, rtol=1e-2, atol=2e-4))
    end
end


# Run the finite-difference gradient check (the only example containing a @testset).
# The two examples below only print values for inspection by eye and are left disabled.
tpbar_example()
#pbptp_example()
#apply_tp()

